package comet

import (
	acsclient "acs/comet/client"
	"acs/comet/config"
	"acs/comet/proto"
	"acs/pbmodel"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"runtime/debug"
	"sync"
	"sync/atomic"
	"time"

	log "github.com/cihub/seelog"
	pb "github.com/golang/protobuf/proto"
)

const (
	// Second is one second as an int64 count of nanoseconds (the unit of time.Duration).
	Second = int64(time.Second)
)

// Package-level state shared by the server goroutines in this file.
//
// exitWG: wait group supplied to StartAppConfigServer; the listener goroutine
// and every HandleTCPConn goroutine are registered on it.
// clientList: client registry supplied to StartAppConfigServer; HandleTCPConn
// adds each client under its remote address string and removes it on exit.
// logger: package logger, seelog's default until replaced via SetLogger.
// setMutex: guards the assignment to logger in SetLogger.
// Conf: active configuration, copied from config.Conf by StartAppConfigServer.
var (
	exitWG     *sync.WaitGroup
	clientList *acsclient.ClientList
	logger     = log.Default
	setMutex   = new(sync.Mutex)
	Conf       *config.Config
)

// Packet-parsing errors carried in PacketResult.err and reported to the peer.
// ErrPacketLenExcess embeds Conf.MaxProtoSize in its message, so it is built by
// StartAppConfigServer and stays nil until that function has run.
var (
	ErrParsePacketLength               = errors.New("Packet length parse error")
	ErrParsePacketTransactionIdAndType = errors.New("Packet transaction ID and type parse error")
	ErrPacketLenExcess                 error
	ErrPacketInvalidType               = errors.New("Invalid packet type.")
)

// appConnectionCounter is the number of live HandleTCPConn goroutines. It is
// updated atomically and compared against Conf.MaxAppConnections by the accept
// loop before a new connection is served.
var appConnectionCounter int32

// PacketResult is what the reader goroutine (readPacketFromClient) delivers to
// HandleTCPConn on its result channel: either a decoded Packet or a failure in
// err/errCode. HandleTCPConn treats the zero value as "the reader has exited"
// (it is what a receive from the closed channel yields).
type PacketResult struct {
	// Tsid is the transaction ID of the packet the result refers to.
	Tsid    uint16
	// Packet is the decoded packet; only meaningful when err is nil.
	Packet  *proto.Packet
	// err describes a read or protocol failure; HandleTCPConn ends the session on it.
	err     error
	// errCode is the numeric code passed to client.HandleProtoError together with err.
	errCode int32
}

// SetLogger replaces the logger used throughout the comet package; until it is
// called, seelog's log.Default is used. The assignment is guarded by setMutex.
//
// NOTE(review): the rest of this file reads logger without taking setMutex, so
// calling SetLogger while connections are being served is a data race — call it
// during start-up only.
func SetLogger(newLogger log.LoggerInterface) {
	setMutex.Lock()
	logger = newLogger
	setMutex.Unlock()
}

// StartAppConfigServer publishes the package configuration (copied from
// config.Conf), the shared client list and the exit wait group, binds a TCP
// listener on addr, and then serves it from a background goroutine registered
// on exitWaitGroup. That goroutine stops when exitService is closed or when
// accepting fails, closing the listener on the way out.
//
// Connections beyond Conf.MaxAppConnections are sent a conn_reject message and
// closed; all others are served by HandleTCPConn on their own goroutine.
//
// Address resolution and listening are done synchronously, so an invalid
// address or a port that is already in use is returned to the caller instead
// of panicking inside the background goroutine.
func StartAppConfigServer(addr string, clients *acsclient.ClientList, exitWaitGroup *sync.WaitGroup) error {
	Conf = config.Conf
	ErrPacketLenExcess = fmt.Errorf("packet length exceeded %v", Conf.MaxProtoSize)
	clientList = clients
	exitWG = exitWaitGroup

	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		logger.Error(err)
		return err
	}
	logger.Infof("starting tcp at %v", addr)
	tl, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		logger.Error(err)
		return err
	}

	exitWG.Add(1)
	go func() {
		defer func() {
			exitWG.Done()
			logger.Flush()
		}()
		// Runs before the Done/Flush defer above (LIFO order).
		defer tl.Close()

		connChan := make(chan *net.TCPConn)
		go acceptTCP(tl, connChan)
		for {
			select {
			case <-exitService:
				return
			case conn := <-connChan:
				// acceptTCP sends nil once accepting has failed.
				if conn == nil {
					return
				}
				tc := atomic.LoadInt32(&appConnectionCounter)
				if tc >= Conf.MaxAppConnections {
					logger.Warnf("maxium client connections [%v] recached[%v]...reject %v.", Conf.MaxAppConnections, tc, conn.RemoteAddr())
					rejectMoreConnections(conn)
					conn.Close()
					continue
				}
				// NOTE(review): appConnectionCounter is incremented inside the
				// HandleTCPConn goroutine, so a burst of connections can briefly
				// exceed Conf.MaxAppConnections.
				go HandleTCPConn(conn)
			}
		}
	}()
	return nil
}

func rejectMoreConnections(conn net.Conn) {
	errMsg, err := pb.Marshal(&pbmodel.ProtoErr{Code: pb.Int(1001), Msg: pb.String(fmt.Sprintf("Maxium connections reached[%v].", Conf.MaxAppConnections))})
	if err != nil {
		logger.Warnf("error message marshal error: %v", err)
		return
	}
	conn.SetDeadline(time.Now().Add(time.Second * 1))
	_, err = conn.Write(proto.NewRequest(0, "conn_reject", errMsg, Conf.DataEncryptKey, Conf.DataEncryptIV).Encode())
	if err != nil {
		logger.Warnf("reject app connection error: %v", err)
	}
}

// acceptTCP accepts connections on tl and forwards each one on ch.
//
// Any accept error ends the goroutine. If the service is exiting, it returns
// silently. Otherwise it logs the error and sends a single nil on ch so the
// receiver knows accepting has stopped. The receiver in StartAppConfigServer
// returns (and closes tl) on nil, so looping again would only fail Accept and
// then block forever on a send nobody receives.
//
// Every send also watches exitService. A connection accepted while the
// receiver is shutting down is closed instead of leaking together with this
// goroutine.
func acceptTCP(tl *net.TCPListener, ch chan *net.TCPConn) {
	for {
		conn, err := tl.AcceptTCP()
		if err != nil {
			select {
			case <-exitService:
				return
			default:
				logger.Warn(err)
			}
			select {
			case ch <- nil:
			case <-exitService:
			}
			return
		}
		select {
		case ch <- conn:
		case <-exitService:
			// The receiver has stopped and will never take this connection.
			conn.Close()
			return
		}
	}
}

// HandleTCPConn serves a single client TCP connection until it fails, idles
// out, or the service exits. Any connection- or protocol-level error should
// close the connection.
//
// It registers itself in appConnectionCounter and exitWG, wraps conn in an
// acsclient.Client that is added to clientList under the remote address, and
// starts readPacketFromClient as the reader goroutine. Packets delivered by the
// reader are dispatched with processPacketResult. A reader error is reported
// with client.HandleProtoError and ends the session. On the way out it waits
// for readerProcessingLock (so a packet the reader has started is finished),
// drains the remaining results, removes the client from clientList and closes
// the client.
//
// NOTE(review): exitWG.Add(1) and the counter increment run inside this
// goroutine (the accept loop starts it with `go`). They can therefore race
// with exitWG.Wait and with the Conf.MaxAppConnections check. The
// sync.WaitGroup docs require Add to happen before the goroutine starts.
func HandleTCPConn(conn *net.TCPConn) {
	atomic.AddInt32(&appConnectionCounter, 1)
	exitWG.Add(1)
	var client *acsclient.Client
	// Undo the registrations above on every exit path, and log (rather than
	// propagate) any panic raised in this goroutine, with a stack trace.
	defer func() {
		atomic.AddInt32(&appConnectionCounter, -1)
		err := recover()
		if err != nil {
			logger.Errorf("client handler routine exited unexpectedly - RemoteAddr: %v, LocalAddr: %v [ %v ]: %s", conn.RemoteAddr(), conn.LocalAddr(), err, debug.Stack())
		}
		exitWG.Done()
	}()
	logger.Debugf("Accepted connection from %v", conn.RemoteAddr().String())
	err := conn.SetKeepAlive(true)
	if err != nil {
		logger.Warn(err)
		conn.Close()
		return
	}
	// NOTE(review): conn is not closed directly on the normal path below;
	// presumably client.Close() closes it — confirm in acsclient.
	client = acsclient.NewClient(&acsclient.ClientConfig{
		Conn:              conn,
		HandshakeTimeout:  Conf.ClientHandshakeTimeout,
		ExitWaitGroup:     exitWG,
		MaxReqRespTimeout: Conf.MaxReqRespTimeout,
		DataEncryptKey:    Conf.DataEncryptKey,
		DataEncryptIV:     Conf.DataEncryptIV,
	})
	ele, _ := clientList.AddClient(conn.RemoteAddr().String(), client)
	// Reader -> handler channel; a zero PacketResult (from the closed channel)
	// means the reader has exited.
	packetResultChan := make(chan PacketResult, 1)
	emptyResult  := PacketResult{}
	readerProcessingLock := new(sync.Mutex)
	readerExitChan := make(chan bool)
	go readPacketFromClient(client, packetResultChan, readerProcessingLock, readerExitChan)

	// Peer idle timeout; Conf.HeartBeatTimeout is in milliseconds.
	tIdle := time.Millisecond * time.Duration(Conf.HeartBeatTimeout)
	var clientTimeout <-chan time.Time

LOOPHDL:
	for {
		// A fresh timer each iteration, so the idle timeout restarts after
		// every packet handled below.
		clientTimeout = time.After(tIdle)
		var packetResult PacketResult
		select {
		// The peer timed out responding, or has been silent for too long.
		case <-clientTimeout:
			logger.Warnf("Client idled or response timedout for %v, closing...", tIdle)
			break LOOPHDL

		case packetResult = <-packetResultChan:

		// The service is shutting down normally.
		case <-exitService:
			break LOOPHDL

		case err = <-client.ConnErr:
			logger.Warnf("client connection broken: %v", err)
			// Stop here so business processing does not go on returning
			// errors unrelated to the connection.
			break LOOPHDL
		}

		// After a read or parse error the protocol offers no way to
		// resynchronize the stream, so stop serving this connection.
		if packetResult.err != nil {
			client.HandleProtoError(packetResult.Tsid, packetResult.errCode, packetResult.err)
			logger.Warn(packetResult.err)
			break LOOPHDL
		}
		// A zero result means packetResultChan was closed: the reader has exited.
		if packetResult != emptyResult {
			processPacketResult(client, &packetResult)
		} else {
			break LOOPHDL
		}

	}

	// Ask the reader to stop before it starts another packet.
	close(readerExitChan)
	// Handle packets that have not been processed yet. Taking the lock waits
	// for a packet the reader has already started; it stays held until this
	// function returns. The reader must not block sending on packetResultChan
	// while holding this lock, or this Lock can never succeed.
	readerProcessingLock.Lock()
	defer readerProcessingLock.Unlock()
	var readerExitTimeout <-chan time.Time
WaitReader:
	for {
		// Give up once no further result arrives within 20ms.
		readerExitTimeout = time.After(time.Millisecond * 20)
		select {
		case <-readerExitTimeout:
			break WaitReader
		case packetResult := <-packetResultChan:
			if packetResult == emptyResult{
				break WaitReader
			}
			if packetResult.err != nil {
				client.HandleProtoError(packetResult.Tsid, packetResult.errCode, packetResult.err)
				logger.Warn(packetResult.err)
			} else {
				processPacketResult(client, &packetResult)
			}
		}
	}

	clientList.RemoveClient(conn.RemoteAddr().String(), ele)
	client.Close()
	logger.Infof("remove connection %v", conn.RemoteAddr().String())
}

// readPacketFromClient is the reader goroutine of one client connection. It
// decodes packets from client.Reader and delivers each one on packetResultChan.
// It stops when readerExitChan is closed (checked before each packet), when a
// read fails, or after the first protocol error. packetResultChan is always
// closed on exit.
//
// Wire layout as parsed here:
//
//	proto.HeadLength bytes   big-endian uint32 payload length (packetLen)
//	proto.VerLength bytes    protocol version (first byte logged, otherwise unused)
//	proto.TsidTypeLen bytes  big-endian uint16: transaction ID << 2 | packet type
//	payload                  request: "cmd\n" + encrypted data (both counted in
//	                         packetLen); response / proto-error: encrypted data
//
// At most one failure is reported per connection:
//
//	400  panic recovered in this goroutine
//	401  read error on the connection
//	10   payload decryption failed
//	11   command line longer than the declared packet length
//	13   packetLen above Conf.MaxProtoSize
//	0    unknown packet type (ErrPacketInvalidType)
//
// readerProcessingLock is held from the moment a packet's length prefix has
// been read until that packet is fully decoded. HandleTCPConn takes it during
// shutdown to let an in-flight packet finish. The lock is always released
// BEFORE sending on packetResultChan. HandleTCPConn takes the lock after it
// has stopped receiving, so a send that blocks on a full channel while the
// lock is held would deadlock both goroutines.
func readPacketFromClient(client *acsclient.Client, packetResultChan chan PacketResult, readerProcessingLock *sync.Mutex, readerExitChan chan bool) {
	var (
		// err is a transport-level failure; the connection cannot be used any more.
		err error
		// protoFailure is a protocol-level failure, reported as-is.
		protoFailure *PacketResult
		// transID is the transaction ID of the packet currently being read.
		transID uint16
		// statusLocked records whether readerProcessingLock is currently held.
		statusLocked bool
	)
	defer func() {
		errI := recover()
		if statusLocked {
			readerProcessingLock.Unlock()
			statusLocked = false
		}
		switch {
		case errI != nil:
			packetResultChan <- PacketResult{
				err:     fmt.Errorf("System error: %v", errI),
				errCode: 400,
				Tsid:    transID,
			}
		case protoFailure != nil:
			packetResultChan <- *protoFailure
		case err != nil:
			packetResultChan <- PacketResult{
				err:     fmt.Errorf("client data read error: %v", err),
				errCode: 401,
				Tsid:    transID,
			}
		}
		close(packetResultChan)
	}()

	for {
		select {
		case <-readerExitChan:
			return
		default:
		}

		// Until this packet's header is parsed its transaction ID is unknown;
		// do not attribute errors to the previous packet's transaction.
		transID = 0

		// Length prefix. io.ReadFull returns a nil error only when the whole
		// buffer was filled, so no separate byte-count check is needed.
		lenBuf := make([]byte, proto.HeadLength)
		if _, err = io.ReadFull(client.Reader, lenBuf); err != nil {
			if err == io.EOF {
				client.Close()
			}
			logger.Debugf("client connection read error where read data length: %v", err)
			return
		}

		// The start of a packet has been read: let it be processed to completion.
		readerProcessingLock.Lock()
		statusLocked = true
		packetLen := binary.BigEndian.Uint32(lenBuf)

		// TODO: the protocol version is currently unused.
		verB := make([]byte, proto.VerLength)
		if _, err = io.ReadFull(client.Reader, verB); err != nil {
			err = fmt.Errorf("client connection read error where read protocol ver： %v", err)
			return
		}
		logger.Debugf("Got client protocol version %v", verB[0])

		// Transaction ID and message type.
		tsbuf := make([]byte, proto.TsidTypeLen)
		if _, err = io.ReadFull(client.Reader, tsbuf); err != nil {
			logger.Warnf("client connection read error where read transaction type and no.: %v", err)
			return
		}
		tsidAndType := binary.BigEndian.Uint16(tsbuf)
		packetType := proto.PacketType(tsidAndType & 3)
		transID = tsidAndType >> 2

		if packetLen > Conf.MaxProtoSize {
			logger.Warnf("Request packet length [%v] exceeded max [%v].", packetLen, Conf.MaxProtoSize)
			protoFailure = &PacketResult{Tsid: transID, err: ErrPacketLenExcess, errCode: 13}
			return
		}

		packet := &proto.Packet{Type: packetType}
		switch packetType {
		case proto.PacketTypeRequest:
			// The command is a '\n'-terminated line counted in packetLen.
			// NOTE(review): ReadBytes is not bounded by Conf.MaxProtoSize, so
			// a peer can make the server buffer an arbitrarily long line.
			var cmdBuf, data []byte
			if cmdBuf, err = client.Reader.ReadBytes('\n'); err != nil {
				logger.Warnf("client connection read error where read cmd content: %v", err)
				return
			}
			if uint64(len(cmdBuf)) > uint64(packetLen) {
				// packetLen-len(cmdBuf) below would wrap around to a ~4 GiB
				// allocation; treat it as a malformed length instead.
				protoFailure = &PacketResult{
					Tsid:    transID,
					err:     fmt.Errorf("%v: command length [%v] exceeds packet length [%v]", ErrParsePacketLength, len(cmdBuf), packetLen),
					errCode: 11,
				}
				return
			}
			cmd := string(cmdBuf[:len(cmdBuf)-1])
			if data, protoFailure, err = readPacketPayload(client, packetLen-uint32(len(cmdBuf)), transID, "request data", "request"); err != nil || protoFailure != nil {
				return
			}
			packet.Req = proto.NewRequest(transID, cmd, data, Conf.DataEncryptKey, Conf.DataEncryptIV)
		case proto.PacketTypeResponse:
			var data []byte
			if data, protoFailure, err = readPacketPayload(client, packetLen, transID, "transaction data", "response"); err != nil || protoFailure != nil {
				return
			}
			packet.Resp = proto.NewResponse(transID, data, Conf.DataEncryptKey, Conf.DataEncryptIV)
		case proto.PacketTypeProtoError:
			var data []byte
			if data, protoFailure, err = readPacketPayload(client, packetLen, transID, "transaction data", "response error"); err != nil || protoFailure != nil {
				return
			}
			packet.ProtoErr = proto.NewProtoError(transID, data, Conf.DataEncryptKey, Conf.DataEncryptIV)
		default:
			protoFailure = &PacketResult{Tsid: transID, err: ErrPacketInvalidType}
			return
		}

		// The packet is complete: release the lock before handing it over.
		statusLocked = false
		readerProcessingLock.Unlock()
		packetResultChan <- PacketResult{Tsid: transID, Packet: packet}
	}
}

// readPacketPayload reads exactly n encrypted payload bytes from the client and
// decrypts them with the configured key and IV. readWhat and decryptWhat only
// label the log and error messages. A read failure is returned as err, because
// the connection is unusable. A decryption failure is returned as a
// protocol-level PacketResult (errCode 10) to be reported to the peer.
func readPacketPayload(client *acsclient.Client, n uint32, transID uint16, readWhat, decryptWhat string) ([]byte, *PacketResult, error) {
	buf := make([]byte, n)
	if _, err := io.ReadFull(client.Reader, buf); err != nil {
		logger.Warnf("client connection read error where read %s: %v", readWhat, err)
		return nil, nil, err
	}
	plain, err := proto.Decrypt(buf, Conf.DataEncryptKey, Conf.DataEncryptIV)
	if err != nil {
		return nil, &PacketResult{
			Tsid:    transID,
			err:     fmt.Errorf("Failed to decrypt %s data for transaction[%v]: %v", decryptWhat, transID, err),
			errCode: 10,
		}, nil
	}
	return plain, nil, nil
}

// processPacketResult hands a successfully decoded packet to the client handler
// for its type. Each handler runs on its own goroutine so the caller can keep
// receiving packets. Packets of an unknown type are logged and dropped.
// packetResult.Packet is dereferenced, so it must be non-nil.
func processPacketResult(client *acsclient.Client, packetResult *PacketResult) {
	pkt := packetResult.Packet
	switch kind := pkt.Type; kind {
	case proto.PacketTypeRequest:
		go client.HandleRequest(pkt.Req)
	case proto.PacketTypeResponse:
		go client.HandleResponse(pkt.Resp)
	case proto.PacketTypeProtoError:
		// An error reported by the peer.
		// TODO: let the goroutine that issued the request learn about this error.
		go client.HandleRemoteProtoError(pkt.ProtoErr)
	default:
		logger.Warnf("un-handlable packet type: %v", kind)
	}
}
